#
# Copyright (c) 2021-2025 Semgrep Inc.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License
# version 2.1 as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the file
# LICENSE for more details.
#
import hashlib
import json
import os
import platform
import uuid
from collections import defaultdict
from datetime import datetime
from enum import auto
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import TYPE_CHECKING
from urllib.parse import urlparse

import click
import requests
from attr import define
from attr import Factory
from typing_extensions import LiteralString

import semgrep.semgrep_interfaces.semgrep_metrics as met
import semgrep.semgrep_interfaces.semgrep_output_v1 as out
from semgrep import __VERSION__
from semgrep.constants import USER_FRIENDLY_PRODUCT_NAMES
from semgrep.error import error_type_string
from semgrep.error import SemgrepError
from semgrep.mcp.models import SemgrepScanResult
from semgrep.parsing_data import ParsingData
from semgrep.profile_manager import ProfileManager
from semgrep.rule import Rule
from semgrep.semgrep_interfaces.semgrep_metrics import AnalysisType
from semgrep.semgrep_interfaces.semgrep_metrics import CodeConfig
from semgrep.semgrep_interfaces.semgrep_metrics import Datetime
from semgrep.semgrep_interfaces.semgrep_metrics import EngineConfig
from semgrep.semgrep_interfaces.semgrep_metrics import Environment
from semgrep.semgrep_interfaces.semgrep_metrics import Errors
from semgrep.semgrep_interfaces.semgrep_metrics import Extension
from semgrep.semgrep_interfaces.semgrep_metrics import FileStats
from semgrep.semgrep_interfaces.semgrep_metrics import Finding
from semgrep.semgrep_interfaces.semgrep_metrics import Interfile
from semgrep.semgrep_interfaces.semgrep_metrics import Interprocedural
from semgrep.semgrep_interfaces.semgrep_metrics import Intraprocedural
from semgrep.semgrep_interfaces.semgrep_metrics import Mcp
from semgrep.semgrep_interfaces.semgrep_metrics import ParseStat
from semgrep.semgrep_interfaces.semgrep_metrics import Payload
from semgrep.semgrep_interfaces.semgrep_metrics import Performance
from semgrep.semgrep_interfaces.semgrep_metrics import RuleStats
from semgrep.semgrep_interfaces.semgrep_metrics import SecretsConfig
from semgrep.semgrep_interfaces.semgrep_metrics import SupplyChainConfig
from semgrep.semgrep_interfaces.semgrep_metrics import Value
from semgrep.semgrep_types import get_frozen_id
from semgrep.types import FilteredMatches
from semgrep.types import TargetInfo
from semgrep.verbose_logging import getLogger

# Used below but can't be imported normally due to a circular dependency
if TYPE_CHECKING:
    from semgrep.engine import EngineType

logger = getLogger(__name__)

# Endpoint that receives the pseudonymous metrics payload
# (POSTed by Metrics._post_metrics below).
METRICS_ENDPOINT = "https://metrics.semgrep.dev"


class MetricsState(Enum):
    """Tri-state switch controlling whether the metrics payload is uploaded.

    ON   -- metrics are always sent
    OFF  -- metrics are never sent
    AUTO -- metrics are sent only if config is pulled from the server
    """

    # Explicit values match what auto() would assign, kept stable for clarity.
    ON = 1
    OFF = 2
    AUTO = 3


class MetricsJsonEncoder(json.JSONEncoder):
    """JSON encoder for the extra value types appearing in a metrics payload.

    Conversions:
    - datetime -> ISO-8601 string in the local timezone
    - uuid.UUID -> its canonical string form
    - set -> sorted list (sorted so the serialized output is deterministic)

    Anything else falls through to the default encoder (and raises TypeError).
    """

    def default(self, obj: Any) -> Any:
        if isinstance(obj, datetime):
            return obj.astimezone().isoformat()

        if isinstance(obj, uuid.UUID):
            return str(obj)

        if isinstance(obj, set):
            # sorted() already returns a list; no need to wrap it in list().
            return sorted(obj)

        return super().default(obj)


# to be mocked to a constant function in test_metrics.py
def mock_float(x: float) -> float:
    """Identity passthrough for floats; tests patch this to a constant."""
    return x


def mock_int(x: int) -> int:
    """Identity passthrough for ints; tests patch this to a constant."""
    return x


@define
class Metrics:
    """
    To prevent sending unintended metrics:
    1. send all data into this class with add_* methods
    2. ensure all add_* methods only set sanitized data

    These methods go directly from raw data to transported data,
    thereby skipping a "stored data" step,
    and enforcing that we sanitize before saving, not before sending.
    """

    _is_using_registry: bool = False
    metrics_state: MetricsState = MetricsState.OFF
    payload: Payload = Factory(
        lambda: Payload(
            environment=Environment(
                version=__VERSION__,
                configNamesHash=met.Sha256(""),
                projectHash=None,
                ci=None,
                isDiffScan=False,
                os=platform.system(),
                isTranspiledJS=False,
            ),
            errors=Errors(),
            performance=Performance(maxMemoryBytes=None),
            extension=Extension(),
            mcp=Mcp(),
            value=Value(features=[]),
            started_at=Datetime(datetime.now().astimezone().isoformat()),
            event_id=met.Uuid(str(get_frozen_id())),
            anonymous_user_id="",
            parse_rate=[],
            sent_at=Datetime(""),
        )
    )

    def __attrs_post_init__(self) -> None:
        # Most CI systems set the CI environment variable; record it verbatim.
        self.payload.environment.ci = os.getenv("CI")

    def log_exception(self, func_name: str, e: Exception) -> None:
        """Log an exception at the debug level

        The add_* methods should not raise exceptions.
        We used to have a @suppress_errors wrapper for automatically catching
        exceptions but it was preventing type checking by mypy
        so we got rid of it."""
        logger.debug(f"Error in Metrics.{func_name}: {e}")

    def configure(
        self,
        metrics_state: Optional[MetricsState],
    ) -> None:
        """
        Configures whether to always, never, or automatically send metrics (based on whether config
        is pulled from the server).

        :param metrics_state: The value of the --metrics option; falls back to
            AUTO when None
        """

        self.metrics_state = metrics_state or MetricsState.AUTO

    # TODO(cooper): It would really be best if EngineType included all of the
    # information here, but I am a bit concerned about changing it, since it is
    # currently an enum. Ideally it would be more like osemgrep's Engine_type.t,
    # but that seems difficult to render here, and would seem to require
    # threading much more information through that type. Since we only really
    # care about the additional information being bundled for metrics, we'll
    # just take some additional parameters here. Currently this is just for
    # secrets, but the same would apply for (supply chain)-related information.
    def add_engine_config(
        self,
        engineType: "EngineType",
        code: Optional[CodeConfig],
        secrets: Optional[SecretsConfig],
        supply_chain: Optional[SupplyChainConfig],
    ) -> None:
        """
        Records which engine was requested plus the per-product configs,
        mapping the engine type to the analysis depth it provides
        (intraprocedural / interprocedural / interfile).
        """
        # Needs to be imported with narrow scope due to a circular dependency if
        # it is imported at the top level.
        from semgrep.engine import EngineType

        try:
            self.payload.value.engineRequested = engineType.name
            analysis_type = {
                EngineType.OSS: AnalysisType(Intraprocedural()),
                EngineType.PRO_LANG: AnalysisType(Intraprocedural()),
                EngineType.PRO_INTRAFILE: AnalysisType(Interprocedural()),
                EngineType.PRO_INTERFILE: AnalysisType(Interfile()),
            }.get(engineType, AnalysisType(Intraprocedural()))
            self.payload.value.engineConfig = EngineConfig(
                analysis_type=analysis_type,
                code_config=code,
                secrets_config=secrets,
                supply_chain_config=supply_chain,
                pro_langs=True,
            )
        except Exception as e:
            self.log_exception("add_engine_config", e)

    def add_interfile_languages_used(self, used_langs: Optional[List[str]]) -> None:
        """
        Records the languages for which interfile analysis was used
        (empty list when None is passed).
        """
        try:
            self.payload.value.interfileLanguagesUsed = used_langs or []
        except Exception as e:
            self.log_exception("add_interfile_languages_used", e)

    def add_is_diff_scan(self, is_diff_scan: bool) -> None:
        """Records whether this run was a diff scan."""
        try:
            self.payload.environment.isDiffScan = is_diff_scan
        except Exception as e:
            self.log_exception("add_is_diff_scan", e)

    @property
    def is_using_registry(self) -> bool:
        return self._is_using_registry

    @is_using_registry.setter
    def is_using_registry(self, value: bool) -> None:
        try:
            self._is_using_registry = value
        except Exception as e:
            self.log_exception("is_using_registry", e)

    def add_project_url(self, project_url: Optional[str]) -> None:
        """
        Standardizes url then hashes
        """
        try:
            if project_url is None:
                self.payload.environment.projectHash = None
                return

            try:
                parsed_url = urlparse(project_url)
                if parsed_url.scheme == "https":
                    # Remove optional username/password from project_url
                    sanitized_url = f"{parsed_url.hostname}{parsed_url.path}"
                else:
                    # For now don't do anything special with other git-url formats
                    sanitized_url = project_url
            except ValueError:
                logger.debug(f"Failed to parse url {project_url}")
                sanitized_url = project_url

            # Only the SHA-256 of the sanitized url is transported, never the
            # url itself.
            m = hashlib.sha256(sanitized_url.encode())
            self.payload.environment.projectHash = met.Sha256(m.hexdigest())
        except Exception as e:
            self.log_exception("add_project_url", e)

    def add_configs(self, configs: Sequence[str]) -> None:
        """
        Assumes configs is list of arguments passed to semgrep using --config
        """
        try:
            # Hash all config names together; only the digest is sent.
            m = hashlib.sha256()
            for c in configs:
                m.update(c.encode())
            self.payload.environment.configNamesHash = met.Sha256(m.hexdigest())
        except Exception as e:
            self.log_exception("add_configs", e)

    def add_rules(self, rules: Sequence[Rule], profile: Optional[out.Profile]) -> None:
        """
        Records a hash over all rule hashes, the rule count, and (when
        profiling data is available) per-rule match time / bytes scanned.
        """
        try:
            # Sort so the aggregate hash is independent of rule order.
            rules = sorted(rules, key=lambda r: r.full_hash)
            m = hashlib.sha256()
            for rule in rules:
                m.update(rule.full_hash.encode())
            self.payload.environment.rulesHash = met.Sha256(m.hexdigest())

            self.payload.performance.numRules = len(rules)
            if profile:
                # aggregate rule stats across files
                _rule_match_times: Dict[out.RuleId, float] = defaultdict(float)
                _rule_bytes_scanned: Dict[out.RuleId, int] = defaultdict(int)
                for i, rule_id in enumerate(profile.rules):
                    for target_times in profile.targets:
                        if target_times.match_times[i] > 0.0:
                            _rule_match_times[rule_id] += target_times.match_times[i]
                            _rule_bytes_scanned[rule_id] += target_times.num_bytes

                self.payload.performance.ruleStats = [
                    RuleStats(
                        ruleHash=rule.full_hash,
                        matchTime=mock_float(_rule_match_times[rule.id2]),
                        bytesScanned=mock_int(_rule_bytes_scanned[rule.id2]),
                    )
                    for rule in rules
                    # We consider only rules with match times and bytes scanned
                    # greater than 0 to avoid making the metrics too bloated.
                    if _rule_match_times[rule.id2] > 0.0
                    and _rule_bytes_scanned[rule.id2] > 0
                ]
        except Exception as e:
            self.log_exception("add_rules", e)

    def add_max_memory_bytes(self, profiling_data: Optional[out.Profile]) -> None:
        """Records peak memory usage from the profiling data, if any."""
        try:
            if profiling_data:
                self.payload.performance.maxMemoryBytes = (
                    profiling_data.max_memory_bytes
                )
        except Exception as e:
            self.log_exception("add_max_memory_bytes", e)

    def add_findings(self, findings: FilteredMatches) -> None:
        """Records finding counts (kept/ignored) and per-rule/per-product breakdowns."""
        try:
            # Rules with 0 findings don't carry a lot of information
            # compared to rules that actually have findings. Rules with 0
            # findings also increase the size of the metrics quite
            # significantly, e.g., when the number of rules grows up to
            # magnitudes of 10k. So we filter them out in the metrics.
            self.payload.value.ruleHashesWithFindings = [
                (r.full_hash, len(f)) for r, f in findings.kept.items() if len(f) > 0
            ]
            self.payload.value.numFindings = sum(len(v) for v in findings.kept.values())
            self.payload.value.numIgnored = sum(
                len(v) for v in findings.removed.values()
            )

            # Breakdown # of findings per-product.
            _num_findings_by_product: Dict[out.Product, int] = defaultdict(int)
            for r, f in findings.kept.items():
                _num_findings_by_product[r.product] += len(f)
            self.payload.value.numFindingsByProduct = [
                (USER_FRIENDLY_PRODUCT_NAMES[p], n_findings)
                for p, n_findings in _num_findings_by_product.items()
            ]
        except Exception as e:
            self.log_exception("add_findings", e)

    def add_targets(
        self, targets: Set[TargetInfo], profile: Optional[out.Profile]
    ) -> None:
        """Records per-file scan stats (when profiled) and total target counts/sizes."""
        try:
            if profile:
                self.payload.performance.fileStats = [
                    FileStats(
                        size=target_times.num_bytes,
                        numTimesScanned=mock_int(
                            len([x for x in target_times.match_times if x > 0.0])
                        ),
                        # TODO: we just have a single parse_time in target_times.parse_times
                        parseTime=mock_float(
                            max(time for time in target_times.parse_times)
                        ),
                        matchTime=mock_float(
                            sum(time for time in target_times.match_times)
                        ),
                        runTime=mock_float(target_times.run_time),
                    )
                    for target_times in profile.targets
                ]
                # Sorted by key so that variation in target order can't be
                # noticed by different ordering of file sizes.
                self.payload.performance.fileStats = sorted(
                    self.payload.performance.fileStats, key=lambda fs: fs.size
                )
            # TODO: fit the data in profile?
            total_bytes_scanned = sum(t.fpath.stat().st_size for t in targets)
            self.payload.performance.totalBytesScanned = total_bytes_scanned
            self.payload.performance.numTargets = len(targets)
        except Exception as e:
            self.log_exception("add_targets", e)

    def add_errors(self, errors: List[SemgrepError]) -> None:
        """Records only the error *types*, never the error messages."""
        try:
            self.payload.errors.errors = [
                met.Error(error_type_string(e.type_())) for e in errors
            ]
        except Exception as e:
            self.log_exception("add_errors", e)

    def add_profiling(self, profiler: ProfileManager) -> None:
        """Records the (name, duration) pairs collected by the profiler."""
        try:
            self.payload.performance.profilingTimes = [
                (k, v) for k, v in profiler.dump_stats().items()
            ]
        except Exception as e:
            self.log_exception("add_profiling", e)

    def add_token(self, token: Optional[str]) -> None:
        """Records only whether a token was present, never the token itself."""
        try:
            self.payload.environment.isAuthenticated = bool(token)
        except Exception as e:
            self.log_exception("add_token", e)

    def add_integration_name(self, name: Optional[str]) -> None:
        """Records the name of the integration driving this scan, if any."""
        try:
            self.payload.environment.integrationName = name
        except Exception as e:
            self.log_exception("add_integration_name", e)

    def add_exit_code(self, exit_code: int) -> None:
        """Records the process exit code."""
        try:
            self.payload.errors.returnCode = exit_code
        except Exception as e:
            self.log_exception("add_exit_code", e)

    def add_version(self, version: str) -> None:
        """Overrides the semgrep version recorded in the environment."""
        try:
            self.payload.environment.version = version
        except Exception as e:
            self.log_exception("add_version", e)

    def add_feature(self, category: LiteralString, name: str) -> None:
        """Records a "category/name" feature marker, keeping the list sorted."""
        try:
            self.payload.value.features.append(f"{category}/{name}")
            self.payload.value.features.sort()
        except Exception as e:
            self.log_exception("add_feature", e)

    def add_registry_url(self, url_string: str) -> None:
        """Records coarse registry usage ("p/<ruleset>" or a redacted "r/" query)."""
        try:
            path = urlparse(url_string).path
            parts = path.lstrip("/").split("/")
            if len(parts) != 2:
                return  # not a simple registry shorthand

            prefix, name = parts

            if prefix == "r":
                # we want to avoid reporting specific rules, so we do this mapping:
                # r/python -> "python"
                # r/python.flask -> "python."
                # r/python.correctness.lang => "python.."
                query_parts = name.split(".")
                dot_count = len(query_parts) - 1
                self.add_feature("registry-query", query_parts[0] + dot_count * ".")
            if prefix == "p":
                self.add_feature("ruleset", name)
        except Exception as e:
            self.log_exception("add_registry_url", e)

    def add_parse_rates(self, parse_rates: ParsingData) -> None:
        """
        Adds parse rates, grouped by language
        """
        try:
            self.payload.parse_rate = [
                (
                    str(lang),
                    ParseStat(
                        targets_parsed=data.num_targets - data.targets_with_errors,
                        num_targets=data.num_targets,
                        bytes_parsed=data.num_bytes - data.error_bytes,
                        num_bytes=data.num_bytes,
                    ),
                )
                for (lang, data) in parse_rates.get_errors_by_lang().items()
            ]
        except Exception as e:
            self.log_exception("add_parse_rates", e)

    def add_extension(
        self,
        machine_id: Optional[str],
        new_install: Optional[bool],
        session_id: Optional[str],
        version: Optional[str],
        type: Optional[str],
    ) -> None:
        """Records editor-extension metadata (machine/session ids, version, type)."""
        try:
            self.payload.extension = Extension(
                machineId=machine_id,
                isNewAppInstall=new_install,
                sessionId=session_id,
                version=version,
                ty=type,
            )
        except Exception as e:
            self.log_exception("add_extension", e)

    def clear_mcp(self) -> None:
        """Resets MCP-related payload fields to their empty state."""
        self.payload.mcp = Mcp()
        self.payload.performance.totalBytesScanned = None
        self.payload.performance.numRules = None
        self.payload.environment.deployment_id = None

    def add_mcp(
        self,
        deployment_id: Optional[int],
        session_id: str,
        deployment_name: Optional[str],
        tool_name: Optional[str],
    ) -> None:
        """Records MCP session identity (deployment, session, tool)."""
        try:
            self.payload.environment.deployment_id = deployment_id
            self.payload.mcp = Mcp(
                deployment_name=deployment_name,
                tool_name=tool_name,
                session_id=session_id,
            )
        except Exception as e:
            self.log_exception("add_mcp", e)

    def add_mcp_scan_metrics(
        self,
        results: SemgrepScanResult,
        num_lines_scanned: int,
    ) -> None:
        """Records results of an MCP scan (counts, findings, errors)."""
        try:
            total_bytes_scanned = int(
                results.mcp_scan_results.get("total_bytes_scanned") or 0
            )
            rules = list(results.mcp_scan_results.get("rules") or [])
            # Fill in some of the performance fields from the MCP scan results that we actually use
            self.payload.performance.totalBytesScanned = total_bytes_scanned
            self.payload.performance.numRules = len(rules)
            self.payload.mcp.num_skipped_rules = len(results.skipped_rules)
            self.payload.mcp.rules = rules
            self.payload.mcp.num_scanned_files = len(results.paths["scanned"])
            self.payload.mcp.num_findings = len(results.results)
            self.payload.mcp.findings = [
                (
                    finding["check_id"],
                    Finding(
                        path=finding["path"],
                        line=finding["start"]["line"],
                        col=finding["start"]["col"],
                        offset=finding["start"]["offset"],
                        severity=finding["extra"]["severity"],
                    ),
                )
                for finding in results.results
            ]
            self.payload.mcp.errors = [error["message"] for error in results.errors]
            self.payload.mcp.num_lines = num_lines_scanned
        except Exception as e:
            # was "add_mcp_scan"; use the actual method name for log consistency
            self.log_exception("add_mcp_scan_metrics", e)

    def add_mcp_git_info(self, git_info: Optional[dict[str, str]]) -> None:
        """Records git username/repo/branch supplied by the MCP caller."""
        try:
            if git_info:
                self.payload.mcp.git_username = git_info["username"]
                self.payload.mcp.git_repo = git_info["repo"]
                self.payload.mcp.git_branch = git_info["branch"]
        except Exception as e:
            self.log_exception("add_mcp_git_info", e)

    def as_json(self) -> str:
        """Serializes the payload to pretty-printed, key-sorted JSON."""
        value = self.payload.to_json()
        return json.dumps(value, indent=2, sort_keys=True, cls=MetricsJsonEncoder)

    @property
    def is_enabled(self) -> bool:
        """
        Returns whether metrics should be sent.

        If metrics_state is:
          - auto, sends if using_registry or if logged in
          - on, sends
          - off, doesn't send
        """
        # import here to prevent circular import
        from semgrep.state import get_state

        state = get_state()

        if self.metrics_state == MetricsState.AUTO:
            # When running logged in with `semgrep ci`, configs are
            # resolved before `self.is_using_registry` is set.
            # However, these scans are still pulling from the registry
            # TODO?
            # using_app = (
            #    state.command.get_subcommand() == "ci"
            #    and state.app_session.is_authenticated
            # )
            using_app = state.app_session.is_authenticated
            return self.is_using_registry or using_app
        return self.metrics_state == MetricsState.ON

    def gather_click_params(self) -> None:
        """Records which CLI params were set, and via which source (flag/env/prompt)."""
        try:
            ctx = click.get_current_context()
            if ctx is None:
                return
            for param in ctx.params:
                source = ctx.get_parameter_source(param)
                if source == click.core.ParameterSource.COMMANDLINE:
                    self.add_feature("cli-flag", param)
                if source == click.core.ParameterSource.ENVIRONMENT:
                    self.add_feature("cli-envvar", param)
                if source == click.core.ParameterSource.PROMPT:
                    self.add_feature("cli-prompt", param)
        except Exception as e:
            self.log_exception("gather_click_params", e)

    # Posting the metrics is separated out so that our tests can check for it
    # TODO it's a bit unfortunate that our tests are going to post metrics...
    def _post_metrics(self, *, user_agent: str, local_scan_id: str) -> None:
        # old: was also logging {self.as_json()}
        # alt: save it in ~/.semgrep/logs/metrics.json?
        logger.debug(f"Sending to {METRICS_ENDPOINT}")
        r = requests.post(
            METRICS_ENDPOINT,
            data=self.as_json(),
            headers={
                "Content-Type": "application/json",
                "User-Agent": user_agent,
                "X-Semgrep-Scan-ID": local_scan_id,
            },
            timeout=3,
        )
        logger.debug(f"response from {METRICS_ENDPOINT} {r.json()}")
        r.raise_for_status()

    def send(self) -> None:
        """
        Send metrics to the metrics server.

        Only sends if is_enabled is True; any failure is logged, not raised.
        """
        try:
            from semgrep.state import get_state  # avoiding circular import

            state = get_state()

            logger.verbose(
                f"{'Sending' if self.is_enabled else 'Not sending'} pseudonymous \
metrics since metrics are configured to {self.metrics_state.name}, \
registry usage is {self.is_using_registry}, and login status is {state.app_session.is_authenticated}"
            )

            if not self.is_enabled:
                return

            self.gather_click_params()
            self.payload.sent_at = Datetime(datetime.now().astimezone().isoformat())

            self.payload.anonymous_user_id = state.settings.get("anonymous_user_id")

            self._post_metrics(
                user_agent=str(state.app_session.user_agent),
                local_scan_id=str(state.local_scan_id),
            )
        except Exception as e:
            self.log_exception("send", e)
